/**
 * \file dnn/src/cuda/cumsum/kern_impl.cuinl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./kern.cuh"
#include "./kern_helper.cuh"
#include "megdnn/dtype.h"
#include "src/cuda/cub/device/device_scan.cuh"
#include "src/cuda/cub/util_ptx.cuh"

namespace megdnn {
namespace cuda {
namespace cumsum {
namespace detail {

/**
  * Blockwise scan kernel.
  *
  * The logical input ("src") has shape (A, B, C) and is read through
  * op.visit(flat_index); the scan is performed blockwise over the B axis.
  * Each CUDA block calculates a blockwise scan result of size (BY2, BX),
  * where BY2 = 2*BY because every thread handles two rows (ty and ty+BY).
  * The block area corresponds to a 2-D area on (B, C) dimension of src.
  *
  * Per-block prefix sum is stored in dst (dst has the same shape as src).
  *
  * The whole scan result of each block as a single value is stored in
  * block_sum (of shape (A, ceil(B/BY2), C)).
  *
  * block_sum can be NULL, in which case the per-block totals are not stored.
  *
  * src and dst can be inplace: every thread reads its own two elements before
  * the first barrier and writes back to the same two positions afterwards.
  *
  * We need to launch ceil(C/BX)*ceil(B/BY2)*A blocks in total.
  * Because in CUDA the number of launched blocks over y and z axis are
  * limited (at most 65535), we launch all blocks over axis x.
  * The kernel indexes threads as (threadIdx.x, threadIdx.y) in [0, BX) x [0, BY).
  *
  * Shared memory: a statically sized array of (BY2+1)*BX elements of T, plus
  * one padding element per 32 elements (see shared_idx).
  *
  * Param: A
  *  Not referenced in the kernel body; the A index is derived from blockIdx.x.
  *
  * Param: exclusive
  *  This flag specifies whether the scan is inclusive or exclusive, namely
  *  whether src_i influences dst_i.
  *
  * Param: reverse
  *  This flag specifies whether the scan is forward or backward.
  *
  * Example:
  *  !exclusive && !reverse: dst_i = op(src_0, src_1, ..., src_i)
  *  !exclusive && reverse: dst_i = op(src_i, src_{i+1}, ..., src_{n-1})
  *  exclusive && !reverse: dst_i = op(src_0, src_1, ..., src_{i-1})
  *  exclusive && reverse: dst_i = op(src_{i+1}, src_{i+2}, ..., src_{n-1})
  *
  * Op should have the following methods:
  *  static T init()
  *      value used for padding and for the exclusive-scan seed
  *  static T apply(T lhs, T rhs)
  *  visit(idx), callable on a const Op with a flat uint32_t index into the
  *      (A, B, C) layout; its result is stored as T
  */
template <typename T, typename Op, bool exclusive, bool reverse,
         uint32_t BY, uint32_t BX>
__global__ void scan_kernel(T *dst, T *block_sum,
        uint32_t A, uint32_t B, uint32_t C, const Op op) {
    constexpr size_t warp_size = 32;
    // each thread processes two rows, so a block covers BY2 rows of the B axis
    const uint32_t BY2 = BY*2;
    // number of blocks needed along B and C (ceil division)
    const uint32_t B_ = (B+BY2-1) / BY2;
    const uint32_t C_ = (C+BX-1) / BX;
    const uint32_t GX = C_;
    const uint32_t GY = B_;
    // src, dst: (A, B, C)
    // block_sum: (A, B_, C)
    // shared: (BY2+1, BX)
    // decompose the flat 1-D block index into (bz, by, bx) = (A, B, C) block coords
    const uint32_t bx = blockIdx.x % GX;
    const uint32_t by = blockIdx.x / GX % GY;
    const uint32_t bz = blockIdx.x / GX / GY;
    const uint32_t tx = threadIdx.x;
    const uint32_t ty = threadIdx.y;
    // TODO: shared memory bank conflict optimization
    // shared_idx inserts one padding slot after every 32 elements
#define shared_idx(x) ((x) + ((x) >> 5))
    volatile __shared__ T cache[shared_idx((BY2+1)*BX)];
    // flat offset of this block's top-left element; dst is rebased to it,
    // op.visit() is always called with the absolute flat index
    uint32_t base_offset = (bz)*B*C + (by*BY2)*C + (bx*BX);
    dst += base_offset;
    // load to cache
    // Row 0 of the cache is reserved for Op::init(); data occupies rows 1..BY2.
    // Out-of-range positions are padded with Op::init(). For a reverse scan the
    // rows are stored mirrored so the same forward algorithm can be used.
    // first half: source row ty
    if (reverse) {
        cache[shared_idx((BY2-ty)*BX+tx)] = ty+by*BY2 < B && tx+bx*BX < C ?
            op.visit(base_offset + ty*C + tx) : Op::init();
    } else {
        cache[shared_idx((ty+1)*BX+tx)] = ty+by*BY2 < B && tx+bx*BX < C ?
            op.visit(base_offset + ty*C + tx) : Op::init();
    }
    // second half: source row ty+BY
    if (reverse) {
        cache[shared_idx((BY-ty)*BX+tx)] =
            (ty+BY) + by*BY2 < B && tx+bx*BX < C ?
            op.visit(base_offset + (ty+BY)*C + tx) : Op::init();
    } else {
        cache[shared_idx((ty+BY+1)*BX+tx)] =
            (ty+BY) + by*BY2 < B && tx+bx*BX < C ?
            op.visit(base_offset + (ty+BY)*C + tx) : Op::init();
    }
    // seed row 0 with the initial value (this is what an exclusive scan
    // reads for its first element)
    if (ty == 0) {
        cache[shared_idx(tx)] = Op::init();
    }
    __syncthreads();
    uint32_t total, stride;
    // first pass
    // Tree reduction over rows: at each level row stride*(2*ty+2) absorbs row
    // stride*(2*ty+1). NOTE(review): for power-of-two BY this leaves the
    // combination of all BY2 data rows in row BY2.
#pragma unroll
    for (total = BY, stride = 1;
            total > 0;
            total >>= 1, stride <<= 1)
    {
        if (ty < total) {
            uint32_t ai = shared_idx(stride * (2*ty+1) * BX + tx);
            uint32_t bi = shared_idx(stride * (2*ty+2) * BX + tx);
            cache[bi] = Op::apply(cache[bi], cache[ai]);
        }
        // The threads that just wrote (ty < total) number total*BX. If that is
        // more than one warp a block-wide barrier is required before the next
        // level reads; otherwise a warp-level sync is used. `total` is uniform
        // across the block, so all threads take the same branch.
        if (total > warp_size/BX) __syncthreads();
        else cub::WARP_SYNC(0xffffffff);
    }
    // second pass
    // Walks back down the tree: row stride*(2*ty+1) absorbs row stride*(2*ty),
    // with stride halving each level. NOTE(review): the intent appears to be
    // that afterwards row r holds the scan of data rows 1..r, which is what
    // the write-out below relies on.
#pragma unroll
    for (total = 1, stride = BY;
            stride > 0;
            total <<= 1, stride >>= 1)
    {
        // same barrier selection as in the first pass, but placed before the
        // reads of this level (ty < total, i.e. total*BX threads)
        if (total > warp_size/BX) __syncthreads();
        else cub::WARP_SYNC(0xffffffff);
        if (ty < total) {
            uint32_t ai = shared_idx(stride * (2*ty+0) * BX + tx);
            uint32_t bi = shared_idx(stride * (2*ty+1) * BX + tx);
            cache[bi] = Op::apply(cache[bi], cache[ai]);
        }
    }
    __syncthreads();
    // inclusive scans read the row holding the element itself (shifted by one
    // because of the init row); exclusive scans read the row before it
    uint32_t ty_offset = (exclusive ? 0 : 1);
    // write back first half (source row ty); reverse undoes the mirrored layout
    if (ty+by*BY2 < B && tx+bx*BX < C) {
        if (reverse) {
            dst[ty*C + tx] = cache[shared_idx((BY2-1-ty+ty_offset)*BX + tx)];
        } else {
            dst[ty*C + tx] = cache[shared_idx((ty+ty_offset)*BX + tx)];
        }
    }
    // write back second half (source row ty+BY)
    if (ty+BY+by*BY2 < B && tx+bx*BX < C) {
        if (reverse) {
            dst[(ty+BY)*C + tx] =
                cache[shared_idx((BY2-1-(ty+BY)+ty_offset)*BX + tx)];
        } else {
            dst[(ty+BY)*C + tx] =
                cache[shared_idx((ty+BY+ty_offset)*BX + tx)];
        }
    }
    // one thread row per column stores the value of the last cache row
    // (row BY2) as this block's entry in block_sum
    if (block_sum && ty == 0 && bx*BX+tx < C) {
        block_sum[(bz)*B_*C + (by)*C + (bx*BX) + tx] =
            cache[shared_idx(BY2*BX + tx)];
    }
}

/**
 * Combines every element of dst with the delta value of the block it
 * belongs to: dst[a, b, c] = Op::apply(dst[a, b, c], delta[a, b / (2*BY), c]).
 *
 * dst has shape (A, B, C); delta has shape (A, ceil(B/(2*BY)), C).
 * The flat 1-D block index is decomposed into (A, B-block, C-block)
 * coordinates; threads are indexed as (threadIdx.x, threadIdx.y) in
 * [0, BX) x [0, BY), and each thread updates rows ty and ty+BY of its block.
 *
 * Param A is not referenced in the body; the A index comes from blockIdx.x.
 */
template <typename T, typename Op, uint32_t BY, uint32_t BX>
__global__ void update_kernel(T *dst, const T *delta,
        uint32_t A, uint32_t B, uint32_t C) {
    const uint32_t rows_per_block = BY * 2;
    const uint32_t grid_y = (B + rows_per_block - 1) / rows_per_block;
    const uint32_t grid_x = (C + BX - 1) / BX;

    // split the flat block index into (bz, by, bx)
    const uint32_t block_yz = blockIdx.x / grid_x;
    const uint32_t bx = blockIdx.x % grid_x;
    const uint32_t by = block_yz % grid_y;
    const uint32_t bz = block_yz / grid_y;

    // column handled by this thread; columns past the end have nothing to do
    const uint32_t col = bx * BX + threadIdx.x;
    if (col >= C) {
        return;
    }

    // one delta value per (A index, B-block, column)
    T delta_v = delta[bz * grid_y * C + by * C + col];

    const uint32_t slice_offset = bz * B * C;
    const uint32_t row_lo = by * rows_per_block + threadIdx.y;
    const uint32_t row_hi = row_lo + BY;

    if (row_lo < B) {
        const uint32_t idx = slice_offset + row_lo * C + col;
        dst[idx] = Op::apply(dst[idx], delta_v);
    }
    if (row_hi < B) {
        const uint32_t idx = slice_offset + row_hi * C + col;
        dst[idx] = Op::apply(dst[idx], delta_v);
    }
}

/**
 * Forward declaration: do_run_kern() recurses into run_kern_multiAC() to scan
 * the per-block totals, and run_kern_multiAC() dispatches back to
 * do_run_kern().
 */
template <typename T, typename Op, bool exclusive, bool reverse>
void run_kern_multiAC(T* dst, T* workspace, uint32_t A, uint32_t B,
                      uint32_t C, const Op& op, cudaStream_t stream);

/**
 * Scan over the B axis of an (A, B, C) input with a fixed thread-block shape
 * of (BX, BY) threads.
 *
 * Steps:
 *  1. scan_kernel computes an independent scan inside each chunk of 2*BY
 *     rows and, if more than one chunk exists, stores each chunk's total in
 *     the first A*B_*C elements of workspace (B_ = ceil(B / (2*BY))).
 *  2. If B fits in a single chunk the result is already complete.
 *  3. Otherwise the chunk totals are scanned in place (always exclusively,
 *     same direction) by a recursive call that uses the remainder of
 *     workspace as its own workspace.
 *  4. update_kernel combines every element with the scanned total of the
 *     chunks preceding it.
 *
 * \param dst output buffer of shape (A, B, C)
 * \param workspace scratch buffer of T; must hold the chunk totals of this
 *      level plus whatever the recursive levels need
 * \param op input accessor passed to scan_kernel; Op must also provide
 *      ContigOp and make_contig(T*) for the recursive level
 * \param stream CUDA stream on which all kernels are enqueued
 *
 * Launch errors are reported through cuda_check right after each launch.
 */
template <typename T, typename Op, bool exclusive, bool reverse,
         uint32_t BX, uint32_t BY>
void do_run_kern(T *dst, T *workspace,
        uint32_t A, uint32_t B, uint32_t C, const Op &op, cudaStream_t stream) {
    // Empty input: nothing to compute. A grid with zero blocks is an invalid
    // launch configuration, so bail out before touching the device.
    if (A == 0 || B == 0 || C == 0)
        return;

    const uint32_t BY2 = BY*2;
    const uint32_t B_ = (B+BY2-1)/BY2;
    const uint32_t C_ = (C+BX-1)/BX;

    // all blocks are laid out along grid axis x
    dim3 blocks(C_*B_*A);
    dim3 threads(BX, BY);

    // block totals are only needed when the B axis spans several chunks
    scan_kernel<T, Op, exclusive, reverse, BY, BX>
        <<<blocks, threads, 0, stream>>>(
                dst, B > BY2 ? workspace : NULL, A, B, C, op);
    // kernel launches do not return a status; fetch launch errors explicitly
    cuda_check(cudaGetLastError());
    if (B <= BY2)
        return;

    // exclusive scan of the chunk totals, in place inside workspace; the
    // region after the A*B_*C totals serves as workspace for that scan
    run_kern_multiAC<T, typename Op::ContigOp, true, reverse>(
                workspace, workspace + A*B_*C, A, B_, C,
                Op::make_contig(workspace), stream);
    update_kernel<T, Op, BY, BX><<<blocks, threads, 0, stream>>>(
            dst, workspace, A, B, C);
    cuda_check(cudaGetLastError());
}

/**
 * Scan over the B axis of an (A, B, C) input for the general case.
 *
 * get_BX_BY() chooses the thread-block shape at run time; since the shape is
 * a template parameter of do_run_kern(), the runtime pair is matched against
 * the supported (BX, BY) combinations (each with BX*BY == 512) and forwarded
 * to the corresponding instantiation. An unsupported pair traps.
 */
template <typename T, typename Op, bool exclusive, bool reverse>
void run_kern_multiAC(T* dst, T* workspace, uint32_t A, uint32_t B, uint32_t C,
                      const Op& op, cudaStream_t stream) {
    uint32_t tile_x, tile_y;
    get_BX_BY(A, B, C, tile_x, tile_y);

#define CUMSUM_DISPATCH_TILE(bx_, by_)                         \
    if (tile_x == (bx_) && tile_y == (by_)) {                  \
        do_run_kern<T, Op, exclusive, reverse, bx_, by_>(      \
                dst, workspace, A, B, C, op, stream);          \
        return;                                                \
    }

    CUMSUM_DISPATCH_TILE(1, 512)
    CUMSUM_DISPATCH_TILE(2, 256)
    CUMSUM_DISPATCH_TILE(4, 128)
    CUMSUM_DISPATCH_TILE(8, 64)
    CUMSUM_DISPATCH_TILE(16, 32)
    CUMSUM_DISPATCH_TILE(32, 16)

#undef CUMSUM_DISPATCH_TILE

    // no instantiation matches the shape returned by get_BX_BY()
    megdnn_trap();
}

//! wrap cub library for 1-dim scan
namespace cubwrap {

/**
 * Read-only random-access iterator that produces the scan input by calling
 * Op::visit() on a logical index.
 *
 * The iterator stores a logical offset instead of a pointer. When \p reverse
 * is set, logical index i is mirrored to len - 1 - i, so a forward scan over
 * this iterator reads the input back to front.
 *
 * operator[] and operator+ are const-qualified: the random-access iterator
 * requirements expect a[n] and a + n to be usable on const iterators, and
 * neither operation modifies the iterator.
 */
template <typename T, typename Op, bool reverse>
class InputIterator : public std::iterator<std::random_access_iterator_tag, T> {
    int m_offset, m_len;
    Op m_op;

public:
    //! constructor without execution-space annotation (used on the host side
    //! by invoke()); the iterator starts at logical index 0
    InputIterator(Op op, int len) : m_offset(0), m_len(len), m_op(op) {}

    //! device-side constructor used by operator+ to create shifted iterators
    __device__ InputIterator(int offset, int len, Op op)
            : m_offset(offset), m_len(len), m_op(op) {}

    //! value at logical position m_offset + idx (mirrored if reverse)
    __device__ T operator[](int idx) const {
        idx += m_offset;
        if (reverse) {
            idx = m_len - 1 - idx;
        }
        return m_op.visit(idx);
    }

    //! iterator advanced by \p offset logical positions
    __device__ InputIterator operator+(int offset) const {
        return InputIterator(m_offset + offset, m_len, m_op);
    }
};

/**
 * Random-access output iterator over a destination buffer of \p len elements.
 *
 * Like InputIterator it tracks a logical offset; when \p reverse is set,
 * logical index i refers to dst[len - 1 - i], so results of a forward scan
 * are written back to front.
 *
 * operator[] and operator+ are const-qualified: they do not modify the
 * iterator itself (only the element it refers to may be written), and the
 * random-access iterator requirements expect them to work on const iterators.
 */
template <typename T, bool reverse>
class OutputIterator
        : public std::iterator<std::random_access_iterator_tag, T> {
    int m_offset, m_len;
    T* m_dst;

public:
    //! constructor without execution-space annotation (used on the host side
    //! by invoke()); the iterator starts at logical index 0
    OutputIterator(T* dst, int len) : m_offset(0), m_len(len), m_dst(dst) {}

    //! device-side constructor used by operator+ to create shifted iterators
    __device__ OutputIterator(int offset, int len, T* dst)
            : m_offset(offset), m_len(len), m_dst(dst) {}

    //! reference to the element at logical position m_offset + idx
    //! (mirrored if reverse)
    __device__ T& operator[](int idx) const {
        idx += m_offset;
        if (reverse) {
            idx = m_len - 1 - idx;
        }
        return m_dst[idx];
    }

    //! iterator advanced by \p offset logical positions
    __device__ OutputIterator operator+(int offset) const {
        return OutputIterator(m_offset + offset, m_len, m_dst);
    }
};

/**
 * Binary functor adapter that exposes Op::apply() in the form expected by
 * the cub scan routines.
 */
template <typename T, typename Op>
struct ScanOp {
    __device__ __host__ T operator()(T a, T b) {
#ifndef __CUDA_ARCH__
        // Host compilation pass: cub needs this operator to be declared
        // __device__ __host__, but MegDNN places no such constraint on
        // Op::apply, so it cannot be called here; trap instead.
        megdnn_trap();
#else
        // Device compilation pass: forward to the real operator.
        return Op::apply(a, b);
#endif
    }
};

/**
 * Run a 1-dim scan of \p len elements through cub::DeviceScan.
 *
 * Input is read via \p op (wrapped in InputIterator) and results are written
 * to \p dst (wrapped in OutputIterator); both iterators mirror their indices
 * when \p reverse is set. Inclusive vs. exclusive is selected by the
 * \p exclusive template flag, with Op::init() as the exclusive-scan seed.
 *
 * \param workspace temporary storage handed to cub
 * \param wk_size size of \p workspace as passed to cub
 * \param stream CUDA stream passed to cub
 */
template <typename T, typename Op, bool exclusive, bool reverse>
void invoke(T* dst, void* workspace, size_t wk_size, const Op& op, uint32_t len,
            cudaStream_t stream) {
    ScanOp<T, Op> scan_op;
    OutputIterator<T, reverse> out_iter(dst, len);
    InputIterator<T, Op, reverse> inp_iter(op, len);

    if (!exclusive) {
        cuda_check(cub::DeviceScan::InclusiveScan(
                workspace, wk_size, inp_iter, out_iter, scan_op, len, stream));
        return;
    }

    cuda_check(cub::DeviceScan::ExclusiveScan(workspace, wk_size, inp_iter,
                                              out_iter, scan_op, Op::init(),
                                              len, stream));
}
}  // namespace cubwrap

} // namespace detail

/**
 * Entry point: scan over the B axis of an (A, B, C) input.
 *
 * When A == 1 and C == 1 the problem is a plain 1-dim scan of B elements and
 * is delegated to the cub wrapper, which receives the raw workspace and its
 * size. Every other shape goes through the blockwise kernels, which treat
 * the workspace as an array of T.
 */
template <typename T, typename Op, bool exclusive, bool reverse>
void run_kern(T* dst, void* workspace, uint32_t workspace_size, uint32_t A,
              uint32_t B, uint32_t C, const Op& op, cudaStream_t stream) {
    const bool is_1dim = (A == 1 && C == 1);
    if (is_1dim) {
        detail::cubwrap::invoke<T, Op, exclusive, reverse>(
                dst, workspace, workspace_size, op, B, stream);
    } else {
        detail::run_kern_multiAC<T, Op, exclusive, reverse>(
                dst, static_cast<T*>(workspace), A, B, C, op, stream);
    }
}

} // namespace cumsum
} // namespace cuda
} // namespace megdnn


// vim: ft=cuda syntax=cuda.doxygen
